e0c27c
@@ -50,7 +50,9 @@
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.io.HiveOutputFormat;
 import org.apache.hadoop.hive.serde.serdeConstants;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
 import org.apache.hadoop.io.Writable;
@@ -128,28 +130,37 @@
     final List<DimensionSchema> dimensions = new ArrayList<>();
     ImmutableList.Builder<AggregatorFactory> aggregatorFactoryBuilder = ImmutableList.builder();
     for (int i = 0; i < columnTypes.size(); i++) {
-      TypeInfo f = columnTypes.get(i);
-      assert f.getCategory() == ObjectInspector.Category.PRIMITIVE;
+      PrimitiveTypeInfo f = (PrimitiveTypeInfo) columnTypes.get(i);
       AggregatorFactory af;
-      switch (f.getTypeName()) {
-        case serdeConstants.TINYINT_TYPE_NAME:
-        case serdeConstants.SMALLINT_TYPE_NAME:
-        case serdeConstants.INT_TYPE_NAME:
-        case serdeConstants.BIGINT_TYPE_NAME:
+      switch (f.getPrimitiveCategory()) {
+        case BYTE:
+        case SHORT:
+        case INT:
+        case LONG:
           af = new LongSumAggregatorFactory(columnNames.get(i), columnNames.get(i));
           break;
-        case serdeConstants.FLOAT_TYPE_NAME:
-        case serdeConstants.DOUBLE_TYPE_NAME:
-        case serdeConstants.DECIMAL_TYPE_NAME:
+        case FLOAT:
+        case DOUBLE:
+        case DECIMAL:
           af = new DoubleSumAggregatorFactory(columnNames.get(i), columnNames.get(i));
           break;
-        default:
-          // Dimension or timestamp
-          String columnName = columnNames.get(i);
-          if (!columnName.equals(DruidTable.DEFAULT_TIMESTAMP_COLUMN) && !columnName
+        case TIMESTAMP:
+          String tColumnName = columnNames.get(i);
+          if (!tColumnName.equals(DruidTable.DEFAULT_TIMESTAMP_COLUMN) && !tColumnName
                   .equals(Constants.DRUID_TIMESTAMP_GRANULARITY_COL_NAME)) {
-            dimensions.add(new StringDimensionSchema(columnName));
+            throw new IOException("Dimension " + tColumnName + " does not have STRING type: " +
+                    f.getPrimitiveCategory());
+          }
+          continue;
+        default:
+          // Dimension
+          String dColumnName = columnNames.get(i);
+          if (PrimitiveObjectInspectorUtils.getPrimitiveGrouping(f.getPrimitiveCategory()) !=
+                  PrimitiveGrouping.STRING_GROUP) {
+            throw new IOException("Dimension " + dColumnName + " does not have STRING type: " +
+                    f.getPrimitiveCategory());
           }
+          dimensions.add(new StringDimensionSchema(dColumnName));
           continue;
       }
       aggregatorFactoryBuilder.add(af);
